{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec96a87b-866e-4eab-b46b-246941325bbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training: fit on the training set first, then refit on training + validation.\n",
    "print(\"Training\")\n",
    "print(\"...\")\n",
    "\n",
    "\n",
    "def _run_epoch(net, loader, criterion, optimizer=None):\n",
    "    '''Run one full pass over `loader` and return (correct_count, summed_loss).\n",
    "\n",
    "    If `optimizer` is given, the network is switched to train mode and its\n",
    "    parameters are updated after every batch. Otherwise the network is switched\n",
    "    to eval mode and the pass runs with gradient tracking disabled.\n",
    "\n",
    "    For each batch, batch[0] is fed to the network and batch[1] is used as the\n",
    "    target. The criterion is expected to return the batch-mean loss (the default\n",
    "    for nn.CrossEntropyLoss), so each value is weighted by the batch size before\n",
    "    summing; dividing the returned sum by the dataset size then gives the real\n",
    "    per-sample mean loss, even when the last batch is smaller than the others.\n",
    "    '''\n",
    "    training = optimizer is not None\n",
    "    net.train(training)  # train mode enables Dropout etc.; eval mode disables it\n",
    "    correct = 0\n",
    "    summed_loss = 0.0\n",
    "    with torch.set_grad_enabled(training):\n",
    "        for batch in loader:\n",
    "            inputs = batch[0].cuda()\n",
    "            labels = batch[1].cuda()  # predictions and labels must be on the same device\n",
    "            if training:\n",
    "                optimizer.zero_grad()  # reset gradients left over from the previous step\n",
    "            logits = net(inputs)  # calls the model's forward()\n",
    "            batch_loss = criterion(logits, labels)\n",
    "            if training:\n",
    "                batch_loss.backward()  # back-propagate to get each parameter's gradient\n",
    "                optimizer.step()  # update the parameters with those gradients\n",
    "            correct += (logits.argmax(dim=1) == labels).sum().item()\n",
    "            summed_loss += batch_loss.item() * labels.size(0)\n",
    "    return correct, summed_loss\n",
    "\n",
    "\n",
    "# Phase 1: train on the training set and use the validation set to look for good settings.\n",
    "model = Classifier().cuda()\n",
    "loss = nn.CrossEntropyLoss()  # classification task, so the loss is cross-entropy\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer\n",
    "num_epoch = 30  # number of passes over the data\n",
    "\n",
    "for epoch in range(num_epoch):\n",
    "    epoch_start_time = time.time()\n",
    "    train_acc, train_loss = _run_epoch(model, train_loader, loss, optimizer)\n",
    "    val_acc, val_loss = _run_epoch(model, val_loader, loss)\n",
    "\n",
    "    # Report per-sample accuracy and mean loss for this epoch.\n",
    "    print('[%03d/%03d] %2.2f sec(s) Train Acc: %3.6f Loss: %3.6f | Val Acc: %3.6f loss: %3.6f' %\n",
    "          (epoch + 1, num_epoch, time.time() - epoch_start_time,\n",
    "           train_acc / len(train_set), train_loss / len(train_set),\n",
    "           val_acc / len(val_set), val_loss / len(val_set)))\n",
    "\n",
    "# Phase 2: train a fresh model on training + validation data combined.\n",
    "train_val_x = np.concatenate((train_x, val_x), axis=0)\n",
    "train_val_y = np.concatenate((train_y, val_y), axis=0)\n",
    "train_val_set = ImgDataset(train_val_x, train_val_y, train_transform)\n",
    "train_val_loader = DataLoader(train_val_set, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "model_best = Classifier().cuda()\n",
    "loss = nn.CrossEntropyLoss()  # classification task, so the loss is cross-entropy\n",
    "optimizer = torch.optim.Adam(model_best.parameters(), lr=0.001)  # Adam optimizer\n",
    "num_epoch = 30\n",
    "\n",
    "for epoch in range(num_epoch):\n",
    "    epoch_start_time = time.time()\n",
    "    train_acc, train_loss = _run_epoch(model_best, train_val_loader, loss, optimizer)\n",
    "\n",
    "    # Report per-sample accuracy and mean loss for this epoch.\n",
    "    print('[%03d/%03d] %2.2f sec(s) Train Acc: %3.6f Loss: %3.6f' %\n",
    "          (epoch + 1, num_epoch, time.time() - epoch_start_time,\n",
    "           train_acc / len(train_val_set), train_loss / len(train_val_set)))\n",
    "\n",
    "print(\"Training completed\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:base] *",
   "language": "python",
   "name": "conda-base-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
